# sgemm_common_128x32

# Copyright 2014 Nervana Systems Inc. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#    http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.


// Prime the software pipeline before entering LOOP: fetch lines 0 and 1 from
// shared memory (through readAs / readBs) into the j0* and j1* register sets.
// The three line-0 loads signal dependency barrier 1 and the three line-1
// loads signal barrier 2.
--:-:1:-:1      LDS.U.128 j0Ay0, [readAs + 4x<0*128 + 00 + 0*8>];
--:-:1:-:1      LDS.U.128 j0Bx0, [readBs + 4x<0*32  + 00 + 0*8>];
--:-:1:-:1      LDS.U.128 j0Ay4, [readAs + 4x<0*128 + 64 + 0*8>];
--:-:2:-:1      LDS.U.128 j1Ay0, [readAs + 4x<1*128 + 00 + 0*8>];
--:-:2:-:1      LDS.U.128 j1Bx0, [readBs + 4x<1*32  + 00 + 0*8>];
--:-:2:-:1      LDS.U.128 j1Ay4, [readAs + 4x<1*128 + 64 + 0*8>];

LOOP:

<CODE>

    # Generates the fully unrolled loop body: 16 lines of 32 FFMAs each, with
    # the shared-memory prefetch for line+2 (and any caller-supplied inserts)
    # interleaved after specific FFMA slots.
    #
    # Package variables provided by the including file:
    #   @top      - text emitted verbatim ahead of the generated body
    #   %insert   - "j<line>c<slot>" => text placed right after that FFMA
    #   $shiftAX  - when true, A loads get an extra (line >> 2) * 8 offset
    #   $shiftBX  - same for B loads
    our (@top, %insert, $shiftAX, $shiftBX);

    # Issue order of the 32 accumulators inside one line: a four-step pattern
    # applied to each row base, with the row bases walked in reverse for the
    # second pair of columns.
    my @quadSteps = ([0,2], [1,2], [1,0], [0,0]);
    my @rowBases  = (0, 1, 4, 5);
    my @ffmaOrder;
    for my $colBase (0, 2)
    {
        for my $rowBase (@rowBases)
        {
            for my $step (@quadSteps)
            {
                push @ffmaOrder, [ $colBase + $step->[0], $rowBase + $step->[1] ];
            }
        }
        @rowBases = reverse @rowBases;
    }

    # Prefetch loads as [slot, operand, format]; the format arguments are
    # barrier, predicate, destination register set, shared line, shift.
    my @prefetch = (
        [ 0, 'A', "--:-:%d:-:1  %s LDS.U.128 j%dAy0, [readAs + 4x<%d*128 + 00 + %d*8>];\n" ],
        [ 2, 'B', "--:-:%d:-:1  %s LDS.U.128 j%dBx0, [readBs + 4x<%d*32  + 00 + %d*8>];\n" ],
        [ 4, 'A', "--:-:%d:-:1  %s LDS.U.128 j%dAy4, [readAs + 4x<%d*128 + 64 + %d*8>];\n" ],
    );

    my $text = join '', @top;

    for my $line (0 .. 15)
    {
        my $ahead   = $line + 2;                    # line being prefetched
        my $bar     = ($line % 2) ? 2 : 1;          # odd lines use barrier 2
        my $guard   = $line < 14 ? '   ' : '@P0';   # last two prefetches are predicated on P0
        my $dstSet  = $ahead % 4;                   # register set receiving the prefetch
        my $srcLine = $ahead % 16;                  # shared line being read
        my $srcSet  = $line % 4;                    # register set consumed by this line's FFMAs
        my %shift   = (
            A => ($shiftAX ? $srcLine >> 2 : 0),
            B => ($shiftBX ? $srcLine >> 2 : 0),
        );

        # Register the prefetch loads; this replaces any existing entry for
        # the same slot of this line.
        for my $load (@prefetch)
        {
            my ($slot, $operand, $format) = @$load;
            $insert{"j${line}c$slot"} = sprintf $format, $bar, $guard, $dstSet, $srcLine, $shift{$operand};
        }

        for my $slot (0 .. 31)
        {
            my ($col, $row) = @{ $ffmaOrder[$slot] };
            my $extra       = $insert{"j${line}c$slot"} || '';

            # Stall 0 when the first inserted instruction is one of the listed
            # ops, otherwise stall 1.
            my $firstLine = (split "\n", $extra)[0];
            my $stall     = $firstLine =~ /LDS|F2F|I2I|LDG|STS|BAR|BRA/ ? 0 : 1;

            # Only the first FFMA of a line waits on that line's barrier.
            my $wait  = $slot ? '--' : "0$bar";

            # Yield once per line, at slot 16, and only on a stalling FFMA.
            my $yield = ($stall && $slot == 16) ? 'Y' : '-';

            $text .= "$wait:-:-:$yield:$stall"
                  .  sprintf("      FFMA cx%dy%d, j%dBx%d, j%dAy%d, cx%dy%d;\n",
                             $col, $row,  $srcSet, $col,  $srcSet, $row,  $col, $row)
                  .  $extra;
        }
    }
    return $text;

</CODE>

<SCHEDULE_BLOCK>

// Post-loop setup: load scalar parameters, pick the time index, and derive
// the C / H base pointers, shared-memory exchange addresses, bounds
// predicates and row strides that STORE_C relies on.
--:-:-:-:1      MOV alpha, param_alpha;
--:-:-:-:1      MOV beta,  param_beta;
--:-:-:-:1      MOV flags, param_flags;
--:-:-:-:5      MOV xcutoff, param_xcutoff;

// offsetC = (flags & 4) ? param_unrolling - time_step - 1 : time_step
--:-:-:-:6      LOP.AND.NZ   P0, RZ, flags, 4;
--:-:-:-:6  @P0 IADD offsetC, -time_step, param_unrolling;
--:-:-:-:6  @P0 IADD offsetC, offsetC,    -1;
--:-:-:-:6 @!P0 MOV  offsetC, time_step;

// baseH = param_H + (dimH * offsetC) * 4
--:-:-:-:1      XMAD     offsetH,   offsetC,   param_dimH, RZ;
--:-:-:-:1      LEA      baseH0.CC, offsetH,   param_H[0],     2;
--:-:-:-:1      LEA.HI.X baseH1,    offsetH,   param_H[1], RZ, 2;

// baseC = param_C + (dimC * offsetC) * 4
--:-:-:-:1      XMAD     offsetC,   offsetC,   param_dimC, RZ;
--:-:-:-:1      LEA      baseC0.CC, offsetC,   param_C[0],     2;
--:-:-:-:1      LEA.HI.X baseC1,    offsetC,   param_C[1], RZ, 2;

// readBs is rebased by the size of the A share area, and both read pointers
// drop swapBuf when swapBuf > 0 (presumably undoing the loop's buffer swap;
// confirm against the including file).
// writeCs = (readAs / 4) * 32 + readBs;
--:-:-:-:1      ISETP.GT.AND P0, PT, swapBuf, RZ, PT;
--:-:-:-:1      IADD readBs,  readBs, -4x<szShareA>;
--:-:-:-:1  @P0 IADD readAs,  readAs, -swapBuf;
--:-:-:-:1  @P0 IADD readBs,  readBs, -swapBuf;
--:-:-:-:1      ISCADD  writeCs, readAs, readBs, 3;

// readCs = (((tid & 96) << 2) | (tid & 31)) << 2;
--:-:-:-:1      LOP.AND tid31, tid, 31;
--:-:-:-:1      LOP.AND tid96, tid, 96;
--:-:-:-:1      ISCADD readCs, tid96, tid31, 2;
--:-:-:-:1      SHL    readCs, readCs, 2;

// cx = blkB*32 + tid31;
--:-:-:-:1      ISCADD cx, blkB, tid31, 5;

// cy = blkA*128 + (tid96 >> 1)
--:-:-:-:1      SHR.U32 cy00, tid96, 1;
--:-:-:-:1      ISCADD  cy00, blkA, cy00, 7;

// C00y = baseC + (ldc*cy + cx) * 4
// The ldcz term is multiplied by RZ here, so it contributes nothing.
--:-:-:-:1      MOV  ldc,  param_ldc;
--:-:-:-:1      MOV  ldcz, param_ldcz;
--:-:-:-:1      XMAD.LO  ci, ldc,  cy00, cx, xmad_c;
--:-:-:-:1      XMAD.LO2 ci, ldcz, RZ, ci;
--:-:-:-:1      LEA      C00y0.CC, ci, baseC0,     2;
--:-:-:-:1      LEA.HI.X C00y1,    ci, baseC1, RZ, 2;

// Apply relu: P4 = (flags & 2) != 0
--:-:-:-:0      LOP.AND.NZ   P4, RZ, flags, 2;
// P6 = cx < n
--:-:-:-:1      ISETP.LT.AND P6, PT, cx, param_n, PT;
// P5 = (beta != 0) && P6
--:-:-:-:1      ISETP.NE.AND P5, PT, beta, RZ, P6;


// Byte strides for C: ldc1 = ldc*4 (1 row), ldc4 = ldc*16 (4 rows),
// ldc60 = ldc*256 - ldc*16 (60 rows).
--:-:-:-:1      SHL  ldc1, ldc, 2;
--:-:-:-:1      SHL  ldc4, ldc, 4;
--:-:-:-:1      ISCADD ldc60, ldc, -ldc4, 8;

--:-:-:-:1      MOV  ldh1, param_ldh;

// H00y = baseH + (ldh*cy + cx) * 4
--:-:-:-:1      XMAD.LO  ci, ldh1,  cy00, cx, xmad_c;
--:-:-:-:1      LEA      H00y0.CC, ci, baseH0,     2;
--:-:-:-:1      LEA.HI.X H00y1,    ci, baseH1, RZ, 2;

// Byte strides for H: ldh1 = ldh*4, ldh4 = ldh*16, ldh60 = ldh*256 - ldh*16.
--:-:-:-:1      SHL  ldh1, ldh1, 2;
--:-:-:-:1      SHL  ldh4, ldh1, 2;
--:-:-:-:1      SHL  ldh60, ldh1, 6;
--:-:-:-:1      IADD ldh60, ldh60, -ldh4;
</SCHEDULE_BLOCK>

// Derive the other three row-group pointers for C (C04y, C08y, C12y), each
// ldc4 bytes past the previous one, using 64-bit adds (low word with .CC,
// high word with .X). The matching row counters cy04 / cy08 / cy12 are
// cy00 + 4 / 8 / 12, and d0..d3 are cleared to zero.
--:-:-:-:4      IADD   C04y0.CC, C00y0, ldc4;
--:-:-:-:1      MOV d0, RZ;
--:-:-:-:1      IADD   cy04, cy00,  4;
--:-:-:-:1      IADD.X C04y1,    C00y1, RZ;
--:-:-:-:4      IADD   C08y0.CC, C04y0, ldc4;
--:-:-:-:1      MOV d1, RZ;
--:-:-:-:1      IADD   cy08, cy00,  8;
--:-:-:-:1      IADD.X C08y1,    C04y1, RZ;
--:-:-:-:3      IADD   C12y0.CC, C08y0, ldc4;
--:-:-:-:1      MOV d2, RZ;
--:-:-:-:1      MOV d3, RZ;
--:-:-:-:1      IADD   cy12, cy00,  12;
--:-:-:-:1      IADD.X C12y1,    C08y1, RZ;

// Same chain for the H pointers, stepping by ldh4 bytes.
--:-:-:-:6      IADD   H04y0.CC, H00y0, ldh4;
--:-:-:-:1      IADD.X H04y1,    H00y1, RZ;
--:-:-:-:6      IADD   H08y0.CC, H04y0, ldh4;
--:-:-:-:1      IADD.X H08y1,    H04y1, RZ;
--:-:-:-:6      IADD   H12y0.CC, H08y0, ldh4;
--:-:-:-:0      IADD.X H12y1,    H08y1, RZ;

// Block-wide barrier before the store sequence starts using shared memory
// (writeCs / readCs) in STORE_C.
--:-:-:-:5      BAR.SYNC 0;

<CODE>

    # Emits eight STORE_C calls, one per accumulator row y0..y7. Before each
    # call the four accumulators of that row are scaled by alpha into c0..c3.
    # Ahead of row 4 every C / H row-group pointer and every cy counter is
    # moved forward by another 60 rows (ldc60 / ldh60 bytes).
    my @rowGroups = qw(00 04 08 12);

    # Build the 60-row skip once; it is only emitted in front of row 4.
    my $skipC = '';
    my $skipH = '';
    for my $g (@rowGroups)
    {
        $skipC .= "--:-:-:-:5      IADD   C${g}y0.CC, C${g}y0, ldc60;\n"
               .  "--:-:-:-:1      IADD   cy${g},     cy${g},  60;\n"
               .  "--:-:-:-:1      IADD.X C${g}y1,    C${g}y1, RZ;\n";
        $skipH .= "--:-:-:-:6      IADD   H${g}y0.CC, H${g}y0, ldh60;\n"
               .  "--:-:-:-:1      IADD.X H${g}y1,    H${g}y1, RZ;\n";
    }

    my $text = '';
    for my $row (0 .. 7)
    {
        $text .= $skipC . "\n" . $skipH if $row == 4;

        # c0..c3 = alpha * cx0..cx3 of this row; the last FMUL has stall 0.
        for my $col (0 .. 3)
        {
            $text .= sprintf "--:-:-:-:%d      FMUL c%d, cx%dy%d, alpha;\n",
                             ($col == 3 ? 0 : 1), $col, $col, $row;
        }

        $text .= "--:-:-:-:5      CAL STORE_C;\n\n";
    }
    return $text;

</CODE>

// After all stores: advance the time step, synchronize every block through a
// global lock word, then either loop back for the next time step or exit.
// NOTE(review): the atomics address [lockAddr]; lockAddr0 / lockAddr1 are
// presumably its two 32-bit halves - confirm in the register mapping.
--:-:-:-:1      MOV lockAddr0, param_lockAddr[0];
--:-:-:-:1      MOV lockAddr1, param_lockAddr[1];

// time_step = time_step + 1
// P0 = time_step < param_unrolling (consumed by the final branch)
--:-:-:-:6      IADD time_step, time_step, 1;
--:-:-:-:1      ISETP.LT.AND P0, PT, time_step, param_unrolling, PT;

// Synchronize all blocks
// P1 = tid != 0
// blkId = blkB * numAblks + blkA, nextBlk = blkId + 1
// P2 = (nextBlk == numBlks) || P1
--:-:-:-:1      ISETP.NE.AND P1, PT, tid, RZ, PT;
--:-:-:-:6      XMAD blkId, blkB, param_numAblks, blkA;
--:-:-:-:6      IADD nextBlk, blkId, 1;
--:-:-:-:8      ISETP.EQ.OR P2, PT, nextBlk, param_numBlks, P1;

--:-:-:-:5      BAR.SYNC 0;

// Phase 1: threads with tid != 0 skip to SSY_TARGET1. Thread 0 of the last
// block (the only tid-0 thread with P2 set) wraps nextBlk to 0.
--:-:-:-:1      SSY SSY_TARGET1;
--:-:-:-:d  @P1 SYNC;
--:-:-:-:6  @P2 MOV nextBlk, RZ;

// Spin until the lock word equals blkId, replacing it with nextBlk.
SPINLOCK1:
--:-:1:Y:2      ATOM.E.CAS lockVal, [lockAddr], blkId, nextBlk;
01:-:-:Y:d      ISETP.NE.AND P1, PT, lockVal, blkId, PT;
--:-:-:-:d  @P1 BRA.U SPINLOCK1;
--:-:-:-:d      SYNC;

// Phase 2: threads with P2 set (tid != 0, or the last block) skip to
// SSY_TARGET2. The remaining threads spin until the lock word reads 0.
SSY_TARGET1:
--:-:-:-:1      SSY SSY_TARGET2;
--:-:-:-:d  @P2 SYNC;
--:-:-:-:6      MOV nextBlk, RZ;

SPINLOCK2:
--:-:1:Y:2      ATOM.E.CAS lockVal, [lockAddr], blkId, nextBlk;
01:-:-:Y:d      ISETP.NE.AND P1, PT, lockVal, RZ, PT;
--:-:-:-:5  @P1 BRA.U SPINLOCK2;
--:-:-:-:d      SYNC;

// Rejoin the rest of the block, then issue a global memory barrier.
SSY_TARGET2:
--:-:-:-:5      BAR.SYNC 0;
--:-:-:-:f      MEMBAR.GL;

//Loop back to beginning of GEMM loop
// RNN_LOOP is not defined in this file; it must come from the including file.
--:-:-:Y:5  @P0 BRA.U RNN_LOOP;

--:-:-:-:5      EXIT;

// STORE_C: write one accumulator row for each of the four row groups
// (C00y, C04y, C08y, C12y).
//
// On entry c0..c3 hold the alpha-scaled accumulators of one row. The routine
//   1. loads h0..h3 from the H row pointers,
//   2. loads d0..d3 from the C row pointers when P5 and the row is < m,
//   3. applies relu to c0..c3 when P4,
//   4. exchanges c0..c3 through shared memory (STS.128 then four LDS),
//   5. adds beta * d when P5,
//   6. zeroes every c whose h is not strictly between 0 and xcutoff,
//   7. stores c0..c3 under the (row < m) && P6 predicates,
//   8. advances the C / H pointers and the cy counters by one row.
// P0..P3 are overwritten; P4..P6 are only read.
STORE_C:

// Each h load signals the same barrier as the d load of its row group
// (barriers 1..4). The waits on the four beta FFMAs below therefore also
// cover h0..h3 before the FSETP instructions read them. Without a barrier
// nothing ordered these loads ahead of their use, in particular when P5 is
// false and the d loads are not issued.
// NOTE(review): these loads are unpredicated, unlike the d loads - confirm H
// is readable for rows >= m and columns >= n.
<SCHEDULE_BLOCK>
--:-:1:-:1      LDG.E h0, [H00y];
--:-:2:-:1      LDG.E h1, [H04y];
--:-:3:-:1      LDG.E h2, [H08y];
--:-:4:-:1      LDG.E h3, [H12y];
</SCHEDULE_BLOCK>

<SCHEDULE_BLOCK>
// Load predicates: row < m and P5.
--:-:-:-:1      ISETP.LT.AND P0, PT, cy00, param_m, P5;
--:-:-:-:1      ISETP.LT.AND P1, PT, cy04, param_m, P5;
--:-:-:-:1      ISETP.LT.AND P2, PT, cy08, param_m, P5;
--:-:-:-:1      ISETP.LT.AND P3, PT, cy12, param_m, P5;

// Fetch the existing C values (barriers 1..4); lanes that do not load get 0.
--:-:1:-:1  @P0 LDG.E d0, [C00y];
--:-:2:-:1  @P1 LDG.E d1, [C04y];
--:-:3:-:1  @P2 LDG.E d2, [C08y];
--:-:4:-:1  @P3 LDG.E d3, [C12y];
--:-:-:-:1 @!P0 MOV d0, RZ;
--:-:-:-:1 @!P1 MOV d1, RZ;
--:-:-:-:1 @!P2 MOV d2, RZ;
--:-:-:-:1 @!P3 MOV d3, RZ;

// Store predicates: row < m and P6.
--:-:-:-:1      ISETP.LT.AND P0, PT, cy00, param_m, P6;
--:-:-:-:1      ISETP.LT.AND P1, PT, cy04, param_m, P6;
--:-:-:-:1      ISETP.LT.AND P2, PT, cy08, param_m, P6;
--:-:-:-:1      ISETP.LT.AND P3, PT, cy12, param_m, P6;

// Row counters move on by one for the next call.
--:-:-:-:1      IADD cy00, cy00, 1;
--:-:-:-:1      IADD cy04, cy04, 1;
--:-:-:-:1      IADD cy08, cy08, 1;
--:-:-:-:3      IADD cy12, cy12, 1;

// Optional relu: c = max(c, 0) when P4.
--:-:-:-:1  @P4 FMNMX c0, c0, RZ, !PT;
--:-:-:-:1  @P4 FMNMX c1, c1, RZ, !PT;
--:-:-:-:1  @P4 FMNMX c2, c2, RZ, !PT;
--:-:-:-:1  @P4 FMNMX c3, c3, RZ, !PT;

// Exchange through shared memory: write four values at writeCs, read back
// one value from each of four 32-word lines at readCs (barriers 5 and 6).
--:-:-:-:1      STS.128 [writeCs], c0;
--:-:-:-:1      LDS c0, [readCs + 4x<0*32>];
--:-:5:-:1      LDS c1, [readCs + 4x<1*32>];
--:-:-:-:1      LDS c2, [readCs + 4x<2*32>];
--:-:6:-:1      LDS c3, [readCs + 4x<3*32>];
</SCHEDULE_BLOCK>

// Save the store predicates P0..P3; the activation test below reuses them.
--:-:-:-:1  P2R predSave, PR, RZ, 0x0f;

// c += beta * d when P5. The wait masks cover barriers 1+5, 2, 3+6 and 4,
// i.e. every d and h global load and every LDS issued above.
11:-:-:-:1  @P5 FFMA c0, d0, beta, c0;
02:-:-:-:1  @P5 FFMA c1, d1, beta, c1;
24:-:-:-:1  @P5 FFMA c2, d2, beta, c2;
08:-:-:-:3  @P5 FFMA c3, d3, beta, c3;

//Bprop for activation: Rectlinclip
// Keep c only where 0 < h < xcutoff, otherwise replace it with 0.
<SCHEDULE_BLOCK>
--:-:-:-:1  FSETP.LT.AND P0, PT, RZ, h0, PT;
--:-:-:-:1  FSETP.LT.AND P1, PT, RZ, h1, PT;
--:-:-:-:1  FSETP.LT.AND P2, PT, RZ, h2, PT;
--:-:-:-:1  FSETP.LT.AND P3, PT, RZ, h3, PT;
--:-:-:-:1  FSETP.LT.AND P0, PT, h0, xcutoff, P0;
--:-:-:-:1  FSETP.LT.AND P1, PT, h1, xcutoff, P1;
--:-:-:-:1  FSETP.LT.AND P2, PT, h2, xcutoff, P2;
--:-:-:-:1  FSETP.LT.AND P3, PT, h3, xcutoff, P3;
--:-:-:-:1  SEL c0, c0, RZ, P0;
--:-:-:-:1  SEL c1, c1, RZ, P1;
--:-:-:-:1  SEL c2, c2, RZ, P2;
--:-:-:-:1  SEL c3, c3, RZ, P3;
</SCHEDULE_BLOCK>

// Restore the store predicates saved in predSave.
--:-:-:Y:d  R2P PR, predSave, 0x0f;

// Predicated stores; each sets a read barrier (1..4) on its address operand.
--:1:-:-:1  @P0 STG.E [C00y], c0;
--:2:-:-:1  @P1 STG.E [C04y], c1;
--:3:-:-:1  @P2 STG.E [C08y], c2;
--:4:-:-:1  @P3 STG.E [C12y], c3;

// Advance each C pointer by one row (ldc1 bytes) once the matching store has
// released it.
01:-:-:-:6      IADD   C00y0.CC, C00y0, ldc1;
--:-:-:-:1      IADD.X C00y1,    C00y1, RZ;
02:-:-:-:6      IADD   C04y0.CC, C04y0, ldc1;
--:-:-:-:1      IADD.X C04y1,    C04y1, RZ;
04:-:-:-:6      IADD   C08y0.CC, C08y0, ldc1;
--:-:-:-:1      IADD.X C08y1,    C08y1, RZ;
08:-:-:-:6      IADD   C12y0.CC, C12y0, ldc1;
--:-:-:-:1      IADD.X C12y1,    C12y1, RZ;

// Advance each H pointer by one row (ldh1 bytes).
--:-:-:-:6      IADD   H00y0.CC, H00y0, ldh1;
--:-:-:-:1      IADD.X H00y1,    H00y1, RZ;
--:-:-:-:6      IADD   H04y0.CC, H04y0, ldh1;
--:-:-:-:1      IADD.X H04y1,    H04y1, RZ;
--:-:-:-:6      IADD   H08y0.CC, H08y0, ldh1;
--:-:-:-:1      IADD.X H08y1,    H08y1, RZ;
--:-:-:-:6      IADD   H12y0.CC, H12y0, ldh1;
--:-:-:-:0      IADD.X H12y1,    H12y1, RZ;

--:-:-:-:5      RET;
